import torch

from .base import BasePruner
from .utils import check_sparsity, find_layers


class UnstructuredPruner(BasePruner):
    """Pruner that zeroes individual weights with the lowest scores per row.

    For every prunable sub-layer, the stored score tensor is sorted along its
    last dimension and the first ``int(n_cols * sparsity_ratio)`` positions of
    each row (i.e. the smallest scores) are set to zero in the layer's weight.
    """

    def __init__(
        self,
        scores="weight",
        sparsity_ratio=0.0,
        n_samples=0,
        seed=0,
        dataset_name="c4",
        **kwargs
    ):
        """Forward all configuration to ``BasePruner`` unchanged.

        Args:
            scores: Scoring method identifier, passed through to the base
                class (its accepted values are defined there).
            sparsity_ratio: Fraction of entries in each score row whose
                corresponding weights are zeroed by :meth:`prune`.
            n_samples: Passed through to the base class
                (presumably the number of calibration samples; confirm in
                ``BasePruner``).
            seed: Passed through to the base class.
            dataset_name: Passed through to the base class.
            **kwargs: Extra keyword arguments passed through to the base class.
        """
        super().__init__(
            scores, sparsity_ratio, n_samples, seed, dataset_name, **kwargs
        )

    def prune(self, model, tokenizer, device):
        """Zero the lowest-scoring weights of ``model`` in place.

        Scores are computed via ``self.calculate_scores`` first if
        ``self.W_metrics`` is still ``None``.

        Args:
            model: Model whose layer stack is reachable at
                ``model.model.layers`` or ``model.model.decoder.layers``.
            tokenizer: Only forwarded to ``self.calculate_scores``.
            device: Only forwarded to ``self.calculate_scores``.

        Returns:
            A tuple ``(model, check_sparsity(model))``: the same model object,
            modified in place, and whatever ``check_sparsity`` reports for it.

        Raises:
            AttributeError: If neither supported layer-stack attribute path
                exists on ``model``.
        """
        if self.W_metrics is None:
            # Presumably populates self.W_metrics; confirm in BasePruner.
            self.calculate_scores(model, tokenizer, device)

        # Two layer-stack layouts are supported. Only a missing attribute
        # should trigger the fallback; anything else (including Ctrl-C) must
        # propagate instead of being silently swallowed.
        try:
            layers = model.model.layers
        except AttributeError:
            layers = model.model.decoder.layers

        # Running index into self.W_metrics; this assumes the scores were
        # stored in the same layer / sub-layer order that is walked here.
        cnt = 0
        for layer in layers:
            subset = find_layers(layer)

            for name in subset:
                W_metric = self.W_metrics[cnt]
                n_prune = int(W_metric.shape[1] * self.sparsity_ratio)

                # Ascending stable sort: the first n_prune indices of each row
                # point at that row's smallest scores.
                sorted_idx = torch.sort(W_metric, dim=-1, stable=True)[1]
                W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
                W_mask.scatter_(1, sorted_idx[:, :n_prune], True)

                # Release the stored scores as soon as they are consumed.
                # NOTE(review): deleting by index while advancing cnt visits
                # every entry only if W_metrics is keyed by integer (dict-like);
                # with a plain list this would skip entries. Confirm against
                # calculate_scores.
                del self.W_metrics[cnt]
                cnt += 1

                subset[name].weight.data[W_mask] = 0

        torch.cuda.empty_cache()

        return model, check_sparsity(model)


def get(**kwargs):
    """Build an ``UnstructuredPruner`` from the given keyword arguments."""
    pruner = UnstructuredPruner(**kwargs)
    return pruner
